iT邦幫忙

2021 iThome 鐵人賽

DAY 18
0
AI & Data

Attention到底在關注什麼?系列 第 18

Day 18 self-attention的實作準備(四) keras的compile和fit

  • 分享至 

  • xImage
  •  

前言

昨天講到要如何建立model,今天來講要如何訓練以及預測

編譯模型

建立完模型之後,必須呼叫compile()方法來指定損失函式與優化法(optimizer)

model.compile(loss="sparse_categorical_crossentropy",
              optimizer="sgd",
              metrics=["accuracy"])

這段程式需要解釋一下。首先,我們使用"sparse_categorical_crossentropy" loss的原因是因為標籤是類別,也就是每一個標籤彼此並不相關。如果是另一個情況,每一個實例都有每個類別的目標機率(例如one-hot向量,比如說[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.]代表3),我們就改用"categorical_crossentropy" loss。如果是在做二元分類,輸出層就要使用sigmoid激活函數,而不是"softmax"激活函數,並且要使用"binary_crossentropy" loss

將optimizer設為sgd代表我們要使用簡單的隨機梯度下降來訓練模型。換句話說,Keras會執行反向傳播法。

最後因為這個是分類器,因此使用 精確度(accuracy) 來評估可以很明顯的看出結果。

訓練與評估模型

這邊我們只需要呼叫模型的fit()方法可以訓練它了:

history = model.fit(X_train, y_train, epochs=30,
                    validation_data=(X_valid, y_valid))

將輸入(X_train)以及目標類別(y_train)傳給它,以及訓練的epoch數,並且傳入驗證組
keras會在epoch結束的時候,使用驗證組來評估損失與其它的指標,可以讓你了解模型訓練的狀況

fit()方法會回傳一個History物件,裡面有訓練參數(history.params)、它經歷的epoch串列(history.epoch),以及一個字典(history.history),裡面有處理訓練組與驗證組的每個epoch完畢時的損失以及其他指標。利用這個字典來建立一個pandas DataFrame並且呼叫plot()方法

import pandas as pd
import matplotlib.pyplot as plt

pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.gca().set_ylim(0, 1)
plt.show()

用模型來預測

接著我們可以使用模型的predict()方法來對test資料做預測
這邊使用測試資料的前三個資料

X_new = X_test[:3]
y_proba = model.predict(X_new)
y_proba.round(2)


這邊可以看到模型計算出來各個類別的機率,從類別0到類別9。

y_pred = np.argmax(y_proba, axis=1)
y_new = y_test[:3]
print(y_pred)
print(y_new)

從這邊可以看到訓練出來的結果和正確答案一樣

參考資料


上一篇
Day 17 self-attention的實作準備(三) 測試使用keras來建立模型
下一篇
Day 19 利用transformer自己實作一個翻譯程式(一)
系列文
Attention到底在關注什麼?30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言